
[CUDA] Implement BlockMaskedMM #3299

Merged
zcbenz merged 9 commits into ml-explore:main from Lyxot:cuda/block-masked-mm on Mar 26, 2026

Conversation

@Lyxot (Contributor) commented Mar 23, 2026

Proposed changes

Implement mx.block_masked_mm for the CUDA backend

Performance

Compared against a naive MLX expand + matmul baseline.

float32:

| Case (MxNxKxBS) | Sparsity | MLX (ms) | Naive (ms) | Speedup |
| --- | --- | --- | --- | --- |
| 256x256x256x32 | 50% | 0.036 | 0.060 | 1.65x |
| 512x512x512x32 | 50% | 0.053 | 0.076 | 1.43x |
| 1024x1024x1024x32 | 50% | 0.141 | 0.169 | 1.20x |
| 2048x2048x2048x64 | 50% | 0.751 | 1.040 | 1.38x |
| 4096x4096x4096x32 | 50% | 4.510 | 5.860 | 1.30x |
| 8192x8192x8192x64 | 50% | 30.916 | 36.695 | 1.19x |
| 4096x4096x11008x32 | 50% | 11.290 | 14.249 | 1.26x |

float16:

| Case (MxNxKxBS) | Sparsity | MLX (ms) | Naive (ms) | Speedup |
| --- | --- | --- | --- | --- |
| 256x256x256x32 | 50% | 0.037 | 0.056 | 1.52x |
| 512x512x512x32 | 50% | 0.043 | 0.066 | 1.52x |
| 1024x1024x1024x32 | 50% | 0.097 | 0.106 | 1.09x |
| 2048x2048x2048x64 | 50% | 0.387 | 0.443 | 1.14x |
| 4096x4096x4096x32 | 50% | 2.359 | 2.992 | 1.27x |
| 8192x8192x8192x64 | 50% | 16.493 | 19.072 | 1.16x |
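The "Naive" column measures an expand + matmul baseline. A minimal NumPy sketch of that baseline (illustrative only — the function names here are hypothetical, and the benchmark's real code runs in MLX; shown with an output mask only):

```python
import numpy as np

def expand_mask(mask, block_size):
    """Expand a (M/BS, N/BS) block mask to element granularity."""
    return np.repeat(np.repeat(mask, block_size, axis=0), block_size, axis=1)

def naive_block_masked_mm(a, b, mask_out, block_size=32):
    """Naive 'expand + matmul' baseline: a dense GEMM followed by an
    elementwise multiply with the expanded output mask."""
    return (a @ b) * expand_mask(mask_out, block_size)
```

The CUDA path avoids materializing the expanded elementwise mask, which is where the speedups in the tables come from.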

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run `pre-commit run --all-files` to format my code / installed `pre-commit` prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works

Lyxot added 3 commits March 23, 2026 19:48
  • Add CUDA implementation for block-masked matrix multiplication. The approach pre-masks input matrices with a simple CUDA kernel, calls cuBLAS GEMM, then applies the output mask.
  • Replace the two-pass approach (contiguous_copy_gpu + apply_block_mask) with a single copy_with_block_mask kernel that reads source data and applies the mask in one pass.
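The semantics of these commits can be illustrated with a NumPy reference (a sketch under stated assumptions, not the PR's CUDA code — the real kernel assigns one element per thread; `block_masked_mm_ref` below is a hypothetical name for the overall pipeline):

```python
import numpy as np

def copy_with_block_mask(src, mask, block_size):
    """Single-pass reference for the fused kernel: each element (one CUDA
    thread in the real kernel) reads the source, looks up its tile's mask
    bit, and writes either the value or zero."""
    rows, cols = src.shape
    out = np.empty_like(src)
    for idx in range(src.size):  # one iteration per thread
        r, c = divmod(idx, cols)
        out[r, c] = src[r, c] if mask[r // block_size, c // block_size] else 0
    return out

def block_masked_mm_ref(a, b, mask_lhs, mask_rhs, mask_out, block_size):
    """Pipeline from the first commit: pre-mask both inputs, run a dense
    GEMM (cuBLAS in the real implementation), then zero masked output tiles."""
    am = copy_with_block_mask(a, mask_lhs, block_size)
    bm = copy_with_block_mask(b, mask_rhs, block_size)
    return copy_with_block_mask(am @ bm, mask_out, block_size)
```

Pre-masking is valid because zeroing a block of A (or B) zeroes exactly that block's contribution to every product term, so the dense GEMM on the masked inputs equals the masked matmul.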
Copilot AI review requested due to automatic review settings on March 23, 2026 14:25
@Lyxot Lyxot force-pushed the cuda/block-masked-mm branch from 8956bdb to b258254 on March 23, 2026 14:26

Copilot AI left a comment


Pull request overview

This PR implements the mx.block_masked_mm primitive for the CUDA backend, enabling block/tile-masked matrix multiplication on NVIDIA GPUs and unskipping the corresponding CUDA tests.

Changes:

  • Add CUDA BlockMaskedMM::eval_gpu implementation wired into the CUDA matmul path.
  • Introduce CUDA kernels/helpers to apply and fuse-copy block masks (apply_block_mask, copy_with_block_mask).
  • Enable CUDA CI coverage for test_block_masked_matmul and add a Python benchmark script.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| python/tests/cuda_skip.py | Unskips the block-masked matmul test on CUDA. |
| mlx/backend/cuda/primitives.cpp | Removes the "no CUDA implementation" stub for BlockMaskedMM. |
| mlx/backend/cuda/matmul.cpp | Adds BlockMaskedMM::eval_gpu and integrates mask-copy/mask-apply around GEMM. |
| mlx/backend/cuda/gemms/block_mask.h | Declares CUDA block-mask helper APIs. |
| mlx/backend/cuda/gemms/block_mask.cu | Implements CUDA kernels for masked copy and in-place output masking. |
| mlx/backend/cuda/CMakeLists.txt | Adds gemms/block_mask.cu to the CUDA build. |
| benchmarks/python/block_masked_mm_bench.py | Adds a benchmark + optional correctness check vs naive expand+matmul. |


@zcbenz (Collaborator) left a comment


Looks good to me, thanks!

@zcbenz zcbenz merged commit 0ff1115 into ml-explore:main Mar 26, 2026
16 checks passed

3 participants